BOJ_2261_가장 가까운 두 점

분할 정복 기반의 Closest Pair(최근접 점 쌍) 알고리즘을 사용하는 문제
어떻게 보면 Closest Pair의 전형이라고 볼 수 있지만, 알고리즘 자체가 생소하기 때문에 구상을 하긴 어려웠다

알고리즘을 공부하고 다시 풀어봤다


import sys  
  
  
def dist(p1, p2):
    """Return the squared Euclidean distance between two 2-D points.

    Each point is an indexable pair: index 0 is x, index 1 is y.
    The square root is never taken, so integer inputs give an exact
    integer result.
    """
    dx, dy = p1[0] - p2[0], p1[1] - p2[1]
    return dx ** 2 + dy ** 2
  
  
def brute_force(points, start, end):
    """Return the smallest squared distance among points[start..end].

    Both ``start`` and ``end`` are inclusive indices into ``points``.
    Every pair in that range is compared, so this is only meant for
    very small ranges. Returns ``float('inf')`` when the range holds
    fewer than two points.
    """
    pair_distances = (
        dist(points[i], points[j])
        for i in range(start, end)
        for j in range(i + 1, end + 1)
    )
    return min(pair_distances, default=float('inf'))
  
# Merge step that orders points by their y coordinate.
def merge_y_sorted(left, right):
    """Merge two y-sorted point lists into one new y-sorted list.

    Points are compared on index 1 (y). When the y values tie, the
    point from ``left`` is placed first. Neither input is modified.
    """
    merged = []
    li = ri = 0
    left_len, right_len = len(left), len(right)

    while li < left_len and ri < right_len:
        from_left, from_right = left[li], right[ri]
        if from_left[1] <= from_right[1]:
            merged.append(from_left)
            li += 1
        else:
            merged.append(from_right)
            ri += 1

    # At most one of these tails is non-empty.
    merged += left[li:]
    merged += right[ri:]
    return merged
  
  
def strip_closest(strip, d):
    """Return the smaller of ``d`` and the closest squared distance in ``strip``.

    ``strip`` is scanned in order; for each point the following points
    are examined until one is found whose squared y-gap is not below the
    best distance so far. That early stop is only meaningful when
    ``strip`` is ordered by ascending y (index 1).
    """
    best = d
    size = len(strip)
    for i, lower in enumerate(strip):
        for j in range(i + 1, size):
            upper = strip[j]
            # Later points are at least this far away in y; stop scanning.
            if (upper[1] - lower[1]) ** 2 >= best:
                break
            candidate = dist(lower, upper)
            if candidate < best:
                best = candidate
    return best
  
  
def closest_pair(points_x, points_y):
    """Return the smallest squared distance between any two points.

    ``points_x`` holds the points ordered by x and drives the recursive
    halving; ``points_y`` holds points ordered by y and is used to build
    the y-ordered strip around the dividing line. Three or fewer points
    are handled by ``brute_force``.

    NOTE(review): the y-list is split with ``x <= dividing x``, while the
    x-list is split by position. When several points share the dividing
    x, the y-lists passed down may not contain exactly the same points
    as the corresponding x-halves.
    """
    count = len(points_x)
    if count <= 3:
        return brute_force(points_x, 0, count - 1)

    half = count // 2
    split_x = points_x[half][0]

    # Partition the y-ordered points around the dividing x; each side
    # keeps its y ordering because the scan is in y order.
    west_y, east_y = [], []
    for pt in points_y:
        if pt[0] <= split_x:
            west_y.append(pt)
        else:
            east_y.append(pt)

    west_best = closest_pair(points_x[:half], west_y)
    east_best = closest_pair(points_x[half:], east_y)
    best = min(west_best, east_best)

    # Only points closer than sqrt(best) to the dividing line can form a
    # better pair across it (compared in squared form).
    band = [pt for pt in points_y if abs(pt[0] - split_x) ** 2 < best]

    return min(best, strip_closest(band, best))
  
  
# Read the point count, then one "x y" pair per line, from standard input.
# sys.stdin.readline is used instead of input() because it is noticeably
# faster when many lines have to be read.
read_line = sys.stdin.readline
n = int(read_line())
points = [list(map(int, read_line().split())) for _ in range(n)]

# closest_pair needs the same points in two orders:
# by x (ties broken by y, since the lists compare lexicographically) and by y.
points_x = sorted(points)
points_y = sorted(points, key=lambda p: p[1])

# The printed answer is the squared distance of the closest pair.
print(closest_pair(points_x, points_y))